import gin
import torch
import torch.nn as nn
import torch.nn.functional as F

@gin.configurable
class PGD:
    """L-infinity projected gradient descent attack with convergence diagnostics.

    ``eps`` and ``alpha`` are given on a 0-255 scale and divided by 255, and the
    attack clamps every iterate to [0, 1], so inputs are expected in [0, 1].

    Args:
        eps: Radius of the L-infinity ball, in 1/255 units.
        alpha: Step size of each iteration, in 1/255 units.
        steps: Number of PGD iterations.
        random_start: If True, start from a uniformly random point inside the
            eps-ball instead of the clean input.
    """

    def __init__(self, eps=8.0, alpha=2, steps=10, random_start=True):
        self.eps = eps / 255
        self.alpha = alpha / 255
        self.steps = steps
        self.random_start = random_start

    @staticmethod
    def _input_grad(model, loss_fn, inputs, labels):
        """Return d loss_fn(model(inputs), labels) / d inputs without mutating ``inputs``."""
        inputs = inputs.detach().requires_grad_(True)
        cost = loss_fn(model(inputs), labels)
        return torch.autograd.grad(cost, inputs, retain_graph=False, create_graph=False)[0]

    def attack(self, model, images, labels):
        """Run PGD on ``images`` and return ``(adv_images, FOSC, SGCS)``.

        Args:
            model: Callable mapping a batch of inputs to logits that
                ``nn.CrossEntropyLoss`` accepts together with ``labels``.
            images: Clean inputs with the batch dimension first, values in [0, 1].
            labels: Class targets for ``nn.CrossEntropyLoss``.

        Returns:
            adv_images: Detached adversarial inputs, same shape as ``images``.
            FOSC: Per-example first-order stationary condition
                ``eps * ||g||_1 - <adv_images - images, g>``, where ``g`` is the
                loss gradient at ``adv_images``. Shape ``(batch,)``.
            SGCS: Per-example mean pairwise cosine similarity between the
                sign-gradients of all iterations. Shape ``(batch,)``. It is all
                zeros when fewer than two steps ran, since no pairs exist.
        """
        images = images.clone().detach()
        labels = labels.clone().detach()
        loss = nn.CrossEntropyLoss()

        adv_images = images.clone().detach()
        if self.random_start:
            # Start at a uniformly random point in the eps-ball.
            adv_images = adv_images + torch.empty_like(adv_images).uniform_(
                -self.eps, self.eps
            )
            adv_images = torch.clamp(adv_images, min=0, max=1).detach()

        sign_grad_list = []
        for _ in range(self.steps):
            grad = self._input_grad(model, loss, adv_images, labels)
            step_sign = torch.sign(grad)
            sign_grad_list.append(step_sign)

            adv_images = adv_images + self.alpha * step_sign
            # Project back onto the eps-ball around the clean input, then the valid range.
            delta = torch.clamp(adv_images - images, min=-self.eps, max=self.eps)
            adv_images = torch.clamp(images + delta, min=0, max=1).detach()

        batch = images.size(0)

        # FOSC must use the attack's own radius, not a hard-coded 8/255.
        grad = self._input_grad(model, loss, adv_images, labels)
        flat_grad = grad.reshape(batch, -1)
        flat_delta = (adv_images - images).reshape(batch, -1)
        FOSC = self.eps * torch.norm(flat_grad, p=1, dim=1) - torch.einsum(
            "ij,ij->i", flat_delta, flat_grad
        )

        # reshape (not view) tolerates non-contiguous tensors.
        flattened_sign_grad_list = [t.reshape(batch, -1) for t in sign_grad_list]
        num_grads = len(flattened_sign_grad_list)
        num_pairs = num_grads * (num_grads - 1) // 2
        SGCS = images.new_zeros(batch)
        for i in range(num_grads):
            for j in range(i + 1, num_grads):
                SGCS = SGCS + F.cosine_similarity(
                    flattened_sign_grad_list[i],
                    flattened_sign_grad_list[j],
                    dim=1,
                    eps=1e-8,
                )
        if num_pairs > 0:
            # Fewer than two steps leaves no pairs; avoid dividing by zero.
            SGCS = SGCS / num_pairs

        return adv_images, FOSC.detach(), SGCS

@gin.configurable
class PGD2:
    """L-infinity projected gradient descent attack returning only adversarial inputs.

    ``eps`` and ``alpha`` are given on a 0-255 scale and divided by 255, and the
    attack clamps every iterate to [0, 1], so inputs are expected in [0, 1].

    Args:
        eps: Radius of the L-infinity ball, in 1/255 units.
        alpha: Step size of each iteration, in 1/255 units.
        steps: Number of PGD iterations.
        random_start: If True, start from a uniformly random point inside the
            eps-ball instead of the clean input.
    """

    def __init__(self, eps=8.0, alpha=2, steps=10, random_start=True):
        self.eps = eps / 255
        self.alpha = alpha / 255
        self.steps = steps
        self.random_start = random_start

    def attack(self, model, images, labels):
        """Run PGD on ``images`` and return the detached adversarial batch.

        Unlike the diagnostic variant, this skips the extra forward/backward
        pass and the pairwise sign-gradient similarity, whose results this
        method never returns.

        Args:
            model: Callable mapping a batch of inputs to logits that
                ``nn.CrossEntropyLoss`` accepts together with ``labels``.
            images: Clean inputs with the batch dimension first, values in [0, 1].
            labels: Class targets for ``nn.CrossEntropyLoss``.

        Returns:
            Adversarial inputs with the same shape as ``images``, detached.
        """
        images = images.clone().detach()
        labels = labels.clone().detach()
        loss = nn.CrossEntropyLoss()

        adv_images = images.clone().detach()
        if self.random_start:
            # Start at a uniformly random point in the eps-ball.
            adv_images = adv_images + torch.empty_like(adv_images).uniform_(
                -self.eps, self.eps
            )
            adv_images = torch.clamp(adv_images, min=0, max=1).detach()

        for _ in range(self.steps):
            adv_images.requires_grad = True
            cost = loss(model(adv_images), labels)
            grad = torch.autograd.grad(
                cost, adv_images, retain_graph=False, create_graph=False
            )[0]

            adv_images = adv_images.detach() + self.alpha * grad.sign()
            # Project back onto the eps-ball around the clean input, then the valid range.
            delta = torch.clamp(adv_images - images, min=-self.eps, max=self.eps)
            adv_images = torch.clamp(images + delta, min=0, max=1).detach()

        return adv_images